# -*- coding: utf-8 -*-  
'''
Initialize cpt2.0: write the pre-training train and validation datasets.

@author: luoyi
Created on 2021-04-07
'''
import utils.conf as conf
import data.dataset_cpt2 as ds_cpt2


#    写入训练集数据
print('开始写入train数据集...')
pret_iterator_train = ds_cpt2.pre_training_iterator(inputs_path=conf.DATASET.get_in_train(), 
                                                    labels_path=conf.DATASET.get_label_train(), 
                                                    count=conf.DATASET.get_count_train())
ds_cpt2.save_pre_training_dataset(pret_iterator=pret_iterator_train, 
                                  tfrecord_dir=conf.GPT2.get_pre_training_tfrecord_train(), 
                                  limit=conf.GPT2.get_pre_training_tfrecord_limit())
print('写入train数据集成功.')


#    写入验证集数据
print('开始写入val数据集...')
pret_iterator_val = ds_cpt2.pre_training_iterator(inputs_path=conf.DATASET.get_in_val(), 
                                                  labels_path=conf.DATASET.get_label_val(), 
                                                  count=conf.DATASET.get_count_val())
ds_cpt2.save_pre_training_dataset(pret_iterator=pret_iterator_val, 
                                  tfrecord_dir=conf.GPT2.get_pre_training_tfrecord_val(), 
                                  limit=conf.GPT2.get_pre_training_tfrecord_limit())
print('写入val数据集成功.')
